import numpy as np
import torch
import os
from ASVBaselineTool.feature_tools import *
import soundfile as sf
import torchaudio
import torch.nn as nn
import librosa

# PGD attack framework (iterative FGSM-style attack projected into an eps-ball)
class PGD():
    """Projected Gradient Descent (PGD) adversarial attack for an audio
    anti-spoofing classifier.

    For each file listed in the protocol, the attack iteratively nudges the
    waveform toward the opposite label (genuine <-> fake) with signed-gradient
    steps, projecting the accumulated perturbation back into an L-inf ball of
    radius ``eps`` and the valid audio range [-1, 1]. Successful adversarial
    waveforms are written to ``AdvsetPath``.
    """

    def __init__(self, model, labelPath, datasetPath, AdvsetPath, steps=80, eps=0.06):
        """
        Args:
            model: classifier under attack; moved to CUDA and set to eval mode.
            labelPath: protocol file whose lines are "<fname> ... <genuine|fake>".
            datasetPath: directory containing the original audio files.
            AdvsetPath: directory where adversarial samples are saved.
            steps: maximum number of PGD iterations per sample.
            eps: L-inf bound on the perturbation.
        """
        with open(labelPath, "r") as proto:
            lines = proto.readlines()
        self.AFnames = [line.split(" ")[0] for line in lines]
        raw_labels = [line.split(" ")[-1].strip() for line in lines]
        # Map file name -> one-hot label: genuine -> [0,1], fake -> [1,0].
        self.LabelMeta = {}
        for fname, label in zip(self.AFnames, raw_labels):
            if label == "genuine":
                label = torch.Tensor([0.0, 1.0])
            elif label == "fake":
                label = torch.Tensor([1.0, 0.0])
            self.LabelMeta[fname] = label

        self.model = model
        self.model.cuda()
        self.model.eval()

        self.datasetPath = datasetPath
        # Output directory for adversarial samples (same file names as input).
        self.AdvSavePath = AdvsetPath
        self.steps = steps
        self.eps = eps

    def feature_select(self, feature, wav: torch.Tensor):
        """Extract the requested front-end feature from a waveform.

        Args:
            feature: one of "spec_2048", "spec_1024", "mfcc_40", "mfcc_80",
                "lfcc_70", "lfcc_60", or "raw".
            wav: 1-D waveform tensor.

        Returns:
            Feature tensor; for "raw" the waveform reshaped to (1, num_samples).

        Raises:
            ValueError: for an unknown feature name (previously this silently
                returned None, causing a confusing failure downstream).
        """
        if feature == "spec_2048":
            return get_SPEC_2048(wav)
        elif feature == "spec_1024":
            return get_SPEC_1024(wav)
        elif feature == "mfcc_40":
            return get_MFCC_40(wav)
        elif feature == "mfcc_80":
            return get_MFCC_80(wav)
        elif feature == "lfcc_70":
            return get_LFCC_70(wav)
        elif feature == "lfcc_60":
            return get_LFCC_60(wav)
        elif feature == "raw":
            # Raw waveform: just add a batch dimension.
            return wav.reshape(1, -1)
        raise ValueError("unknown feature type: " + str(feature))

    def model_bn_to_eval(self):
        """Put the model in train mode but freeze every BatchNorm layer.

        Keeps BN running statistics untouched while the rest of the network
        stays in train mode so gradients flow to the raw-waveform input.
        NOTE(review): the layer layout (first_bn, block0..block5, bn_before_gru)
        looks like a RawNet2-style model — confirm against the model class.
        """
        self.model.train()
        self.model.first_bn.eval()
        self.model.block0[0].bn2.eval()
        for block in (self.model.block1, self.model.block2, self.model.block3,
                      self.model.block4, self.model.block5):
            block[0].bn1.eval()
            block[0].bn2.eval()
        self.model.bn_before_gru.eval()

    def Attack(self, feature: str, alpha=0.0001, stop=0.1):
        """Run the PGD attack over every file in the protocol.

        Args:
            feature: front-end feature name passed to :meth:`feature_select`.
            alpha: per-step size of the signed-gradient update.
            stop: early-stop threshold on the (targeted) loss.

        Side effects: writes successful adversarial waveforms to
        ``self.AdvSavePath`` and prints per-file / summary statistics.
        """
        print("attack alpha:" + str(alpha))
        num_f = len(self.AFnames)
        print("总样本：" + str(num_f))
        num_as = 0  # count of successfully attacked samples
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if feature == "raw":
            # Freeze BN once up front (idempotent; was re-run per file).
            self.model_bn_to_eval()
        lossfun = nn.CrossEntropyLoss()
        for item in self.AFnames:
            src_wav, sr = sf.read(os.path.join(self.datasetPath, item))
            src_wav = torch.Tensor(src_wav)
            # Initial prediction on the clean sample.
            clean_feature = self.feature_select(feature, src_wav).cuda()
            init_pred = torch.max(self.model(clean_feature), dim=1)[1].item()
            y = self.LabelMeta[item]
            # Flip the one-hot label: the attack is targeted at the wrong class.
            targety = (-1 * y + 1).cuda().reshape(1, 2)
            init_y = y.max(0, keepdim=True)[1].item()
            # Already misclassified: nothing to attack.
            if init_pred != init_y:
                print("原始分类错误 pass")
                continue
            # --- PGD core loop ---
            perturbed_wav = src_wav.detach().clone()
            for _ in range(self.steps):
                # Leaf tensor on CPU so .grad is populated by backward().
                perturbed_wav = perturbed_wav.detach().cpu()
                perturbed_wav.requires_grad = True
                adv_feature = self.feature_select(feature, perturbed_wav).cuda()
                loss = lossfun(self.model(adv_feature), targety)
                self.model.zero_grad()
                loss.backward()
                sign_data_grad = perturbed_wav.grad.data.sign()
                # Descend the targeted loss (step toward the flipped label).
                perturbed_wav = perturbed_wav.detach() - alpha * sign_data_grad
                # Project into the eps-ball around the clean waveform and
                # clamp to the valid audio range [-1, 1].
                delta = torch.clamp(perturbed_wav - src_wav, min=-self.eps, max=self.eps)
                perturbed_wav = torch.clamp(src_wav + delta, min=-1, max=1)
                if loss.item() < stop:
                    break
            # Re-evaluate the perturbed waveform (no gradient needed here;
            # the old requires_grad=True and redundant .cuda() were dropped).
            perturbed_wav = perturbed_wav.detach().cpu()
            temp_feature = self.feature_select(feature, perturbed_wav).to(device)
            temp_pred = self.model(temp_feature).max(1)[1].item()
            if temp_pred == init_y:
                print("攻击失败 pass")
                continue
            sp = os.path.join(self.AdvSavePath, item)
            print("攻击成功并保存值路径:" + sp)
            num_as += 1
            # Write with the source file's own sample rate (was hardcoded 16000,
            # which silently resampled/mislabeled non-16k audio).
            sf.write(sp, perturbed_wav.numpy(), sr)
        print("总样本：" + str(num_f))
        print("攻击成功样本数量:" + str(num_as))
        # Guard against an empty protocol file (avoid ZeroDivisionError).
        rate = num_as / num_f if num_f else 0.0
        print("攻击成功率:" + str(rate))




















